{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1e400b63",
   "metadata": {},
   "source": [
    "## 数据集定义与加载"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f974190e",
   "metadata": {},
   "source": [
    "[官网地址](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/02_paddle2.0_develop/02_data_load_cn.html)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b37f551",
   "metadata": {},
   "outputs": [],
   "source": [
    "import paddle\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "029184c7",
   "metadata": {},
   "source": [
    "### 框架自带的数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b00aa089",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "视觉相关的数据集 ['DatasetFolder', 'ImageFolder', 'MNIST', 'FashionMNIST', 'Flowers', 'Cifar10', 'Cifar100', 'VOC2012']\n",
      "文本相关的数据集 ['Conll05st', 'Imdb', 'Imikolov', 'Movielens', 'UCIHousing', 'WMT14', 'WMT16', 'ViterbiDecoder', 'viterbi_decode']\n"
     ]
    }
   ],
   "source": [
    "\n",
    "print(\"视觉相关的数据集\",paddle.vision.datasets.__all__)\n",
    "print(\"文本相关的数据集\",paddle.text.__all__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "6ec0ada7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from paddle.vision.transforms import ToTensor"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "78587e6c",
   "metadata": {},
   "source": [
    "从远端下载到本机的缓存目录"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "c14eaa5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = paddle.vision.datasets.MNIST(mode='train',transform=ToTensor())\n",
    "val_dataset = paddle.vision.datasets.MNIST(mode='test',transform=ToTensor())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d870863",
   "metadata": {},
   "source": [
    "### 自定义数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "18348338",
   "metadata": {},
   "outputs": [],
   "source": [
    "import paddle\n",
    "from paddle.io import Dataset\n",
    "\n",
    "BATCH_SIZE = 64\n",
    "BATCH_NUM = 20\n",
    "\n",
    "IMAGE_SIZE = (28,28)\n",
    "CLASS_NUM = 10\n",
    "\n",
    "class MyDataset(Dataset):\n",
    "    \"\"\"\n",
    "    步骤一：继承paddle.io.Dataset类\n",
    "    \"\"\"\n",
    "    def __init__(self,num_samples):\n",
    "        \"\"\"\n",
    "        步骤二：实现构造函数，定义数据集大小\n",
    "        \"\"\"\n",
    "        super(MyDataset,self).__init__()\n",
    "        self.num_samples = num_samples\n",
    "    \n",
    "    def __getitem__(self,index):\n",
    "        \"\"\"\n",
    "        步骤三：重写__getitem__方法，定义指定index时如何获取数据，返回数据及标签\n",
    "        \"\"\"\n",
    "        data = paddle.uniform(IMAGE_SIZE,dtype=\"float32\")\n",
    "        label = paddle.randint(0,CLASS_NUM-1,dtype='int64')\n",
    "        return data,label\n",
    "    \n",
    "    def __len__(self):\n",
    "        \"\"\"\n",
    "        步骤四：重写__len__方法，返回数据集的数目\n",
    "        \"\"\"\n",
    "        return self.num_samples\n",
    "    \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "82ace679",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==========custom dataset=============\n",
      "[28, 28] [1]\n"
     ]
    }
   ],
   "source": [
    "custom_dataset = MyDataset(BATCH_SIZE*BATCH_NUM)\n",
    "print(\"==========custom dataset=============\")\n",
    "for data,label in custom_dataset:\n",
    "    print(data.shape,label.shape)\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "605661da",
   "metadata": {},
   "source": [
    "### 数据加载\n",
    ">paddle.io.DataLoader 采用异步加载数据的方式读取数据，提升数据加载的速度"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "fdc53e37",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = paddle.io.DataLoader(custom_dataset,batch_size=BATCH_SIZE,shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "e55a83ea",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[64, 28, 28]\n",
      "[64, 1]\n"
     ]
    }
   ],
   "source": [
    "for batch_id,data in enumerate(train_data):\n",
    "    x_data = data[0]\n",
    "    y_data = data[1]\n",
    "    print(x_data.shape)\n",
    "    print(y_data.shape)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bbec1af",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "ddd33511",
   "metadata": {},
   "source": [
    "## 数据预处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "5fcee0f9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "数据处理方法： ['BaseTransform', 'Compose', 'Resize', 'RandomResizedCrop', 'CenterCrop', 'RandomHorizontalFlip', 'RandomVerticalFlip', 'Transpose', 'Normalize', 'BrightnessTransform', 'SaturationTransform', 'ContrastTransform', 'HueTransform', 'ColorJitter', 'RandomCrop', 'Pad', 'RandomRotation', 'Grayscale', 'ToTensor', 'to_tensor', 'hflip', 'vflip', 'resize', 'pad', 'rotate', 'to_grayscale', 'crop', 'center_crop', 'adjust_brightness', 'adjust_contrast', 'adjust_hue', 'normalize']\n"
     ]
    }
   ],
   "source": [
    "print(\"数据处理方法：\",paddle.vision.transforms.__all__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4dd5ff0b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {
    "height": "calc(100% - 180px)",
    "left": "10px",
    "top": "150px",
    "width": "165px"
   },
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
